import torch.nn as nn
import torch.nn.functional as F
import torch

from . import ddn, ddn_loss
from pcdet.models.model_utils.basic_block_2d import BasicBlock2D

from pcdet.datasets.processor.data_processor import VoxelGeneratorWrapper


class DepthFFN(nn.Module):

    def __init__(self, model_cfg, downsample_factor):
        """
        Initialize frustum feature network via depth distribution estimation
        Args:
            model_cfg: EasyDict, Depth classification network config
            downsample_factor: int, Depth map downsample factor
        """
        super().__init__()
        self.model_cfg = model_cfg
        self.disc_cfg = model_cfg.DISCRETIZE
        self.downsample_factor = downsample_factor

        # Depth distribution network. One extra class beyond num_bins is the
        # "> max range" category that create_frustum_features discards.
        self.ddn = ddn.DDNTemplate(
            num_classes=self.disc_cfg["num_bins"] + 1,
            **model_cfg.DDN.ARGS
        )

        ##### original CaDDN components #####
        self.channel_reduce = BasicBlock2D(**model_cfg.CHANNEL_REDUCE)
        self.ddn_loss = ddn_loss.__all__[model_cfg.LOSS.NAME](
            disc_cfg=self.disc_cfg,
            downsample_factor=downsample_factor,
            **model_cfg.LOSS.ARGS
        )
        self.forward_ret_dict = {}

        # 3D refinement heads over the frustum volume. Their input width is the
        # channel-reduce output, so derive it from that block instead of
        # hard-coding it (it used to be a literal 16, which broke any other
        # CHANNEL_REDUCE.out_channels at forward time).
        self.fru3D, self.fru3D_prob = self._build_frustum_head(self.channel_reduce.out_channels)

    @staticmethod
    def _build_frustum_head(channels):
        """
        Build the 3D frustum refinement head and its single-channel logit head.
        Args:
            channels: int, Frustum feature channels (must be divisible by 2 for GroupNorm(2, ...))
        Returns:
            fru3D: nn.Sequential, (N, channels, D, H, W) -> (N, channels, D, H, W)
            fru3D_prob: nn.Sequential, (N, channels, D, H, W) -> (N, 1, D, H, W) logits
        """
        fru3D = nn.Sequential(
            nn.Conv3d(channels, channels, 3, 1, 1),
            nn.GroupNorm(2, channels),
            nn.ReLU(inplace=True),
            nn.Conv3d(channels, channels, 3, 1, 1),
            nn.GroupNorm(2, channels),
            nn.ReLU(inplace=True),
        )
        fru3D_prob = nn.Sequential(
            nn.Conv3d(channels, channels, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv3d(channels, 1, 3, 1, 1),
        )
        return fru3D, fru3D_prob

    def get_output_feature_dim(self):
        """Return the channel count C of the produced frustum features."""
        return self.channel_reduce.out_channels

    def forward(self, batch_dict):
        """
        Predicts depths and creates image depth feature volume using depth distributions
        Args:
            batch_dict:
                images: (N, 3, H_in, W_in), Input images
                depth_maps, gt_boxes2d: Required only in training mode (cached for get_loss)
        Returns:
            batch_dict:
                frustum_features_prob: (N, 1, D, H_out, W_out), Learned per-voxel weights,
                    sigmoid output clamped to [1e-5, 1 - 1e-5]
                frustum_features: (N, C, D, H_out, W_out), Image features weighted by
                    frustum_features_prob
                depth_probs: (N, D+1, H_out, W_out), Softmax of the depth logits, including
                    the last (> max range) bin
        """
        # Pixel-wise depth classification
        images = batch_dict["images"]
        ddn_result = self.ddn(images)
        image_features = ddn_result["features"]
        depth_logits = ddn_result["logits"]

        # Channel reduce
        if self.channel_reduce is not None:
            image_features = self.channel_reduce(image_features)

        # Create image feature plane-sweep volume
        frustum_features = self.create_frustum_features(image_features=image_features,
                                                        depth_logits=depth_logits)

        # The depth-weighted volume only feeds the 3D heads; the final output
        # re-weights the raw image features by the learned per-voxel probability.
        frustum_features = self.fru3D(frustum_features)

        frustum_features_prob = self.fru3D_prob(frustum_features)
        frustum_features_prob = torch.clamp(torch.sigmoid(frustum_features_prob), 1e-5, 1 - 1e-5)
        batch_dict["frustum_features_prob"] = frustum_features_prob
        batch_dict["frustum_features"] = image_features.unsqueeze(2) * frustum_features_prob

        depth_probs = F.softmax(depth_logits, dim=1)
        batch_dict["depth_probs"] = depth_probs

        if self.training:
            # Cached for get_loss(), which forwards these as kwargs to self.ddn_loss
            self.forward_ret_dict["depth_maps"] = batch_dict["depth_maps"]
            self.forward_ret_dict["gt_boxes2d"] = batch_dict["gt_boxes2d"]
            self.forward_ret_dict["depth_logits"] = depth_logits
        return batch_dict

    def create_frustum_features(self, image_features, depth_logits):
        """
        Create image depth feature volume by multiplying image features with depth distributions
        Args:
            image_features: (N, C, H, W), Image features
            depth_logits: (N, D+1, H, W), Depth classification logits
        Returns:
            frustum_features: (N, C, D, H, W), Image features
        """
        channel_dim = 1
        depth_dim = 2

        # Resize to match dimensions
        image_features = image_features.unsqueeze(depth_dim)
        depth_logits = depth_logits.unsqueeze(channel_dim)

        # Apply softmax along depth axis and remove last depth category (> Max Range)
        depth_probs = F.softmax(depth_logits, dim=depth_dim)
        depth_probs = depth_probs[:, :, :-1]

        # Multiply to form image depth feature volume
        frustum_features = depth_probs * image_features

        return frustum_features

    def get_loss(self):
        """
        Gets DDN loss
        Args:
        Returns:
            loss: (1), Depth distribution network loss
            tb_dict: dict[float], All losses to log in tensorboard
        """
        loss, tb_dict = self.ddn_loss(**self.forward_ret_dict)
        return loss, tb_dict


def gen_center_depths(depth_min=2, depth_max=46.8, num_bins=80, mode='LID'):
    """
    Compute the depth associated with each depth-bin index.

    Despite the name, the returned values are the depths at which each bin
    starts (index 0 maps to depth_min in both modes), not bin midpoints.
    Args:
        depth_min: float, Lower end of the discretized depth range
        depth_max: float, Upper end of the discretized depth range
        num_bins: int, Number of depth bins
        mode: str, "UD" (uniform widths) or "LID" (widths growing linearly with index)
    Returns:
        center_depth: (num_bins), Depth for each bin index
    Raises:
        NotImplementedError: if mode is neither "UD" nor "LID"
    """
    bin_ids = torch.arange(num_bins)
    depth_range = depth_max - depth_min

    if mode == "LID":
        # Bin k has width (k + 1) * lid_step, so the k-th start is k(k+1)/2 * lid_step.
        lid_step = 2 * depth_range / (num_bins * (1 + num_bins))
        return ((1 + 2 * bin_ids) ** 2 - 1) * lid_step / 8 + depth_min

    if mode == "UD":
        # Every bin has the same width.
        ud_step = depth_range / num_bins
        return ud_step * bin_ids + depth_min

    raise NotImplementedError